Embedding Cache Agent
This agent demonstrates memory-intensive workloads and warm-start performance. It loads a heavy embedding model (sentence-transformers) into memory once and reuses it for all subsequent requests handled by the same worker.
How it Works
- Startup: The embedding model (~400MB) and up to MAX_COLLECTIONS FAISS indices are loaded into global memory once per worker and stay resident for its lifetime.
- Request: Receives a JSON payload whose action field selects load, search, or status.
- Processing: Runs similarity search over the pre-loaded indices using the cached embedding model; nothing is reloaded per request.
- Output: Returns the results (or cache status) along with the worker's current memory usage, as in the example after this list.
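A request/response pair for the search action looks roughly like this (the values are illustrative; the exact fields are documented in the handler docstring in agent.py below):
Request:
{"action": "search", "collection": "collection_1", "query": "What is vector search?", "top_k": 3}
Response (abridged):
{"status": "success", "action": "search", "collection": "collection_1", "results": [{"content": "…", "metadata": {}}], "retrieval_ms": 12, "memory_mb": 687.3}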
Key Features
- High Memory Config: The memory setting in agent.yaml sizes the worker for the model and indices (e.g., memory: 4096 for 4GB); the example config below deliberately sets it to 512 so OOM detection and recovery can be exercised.
- Zero Cold Starts: Subsequent requests are millisecond-fast because the model stays in memory.
- Singleton Pattern: Uses a global variable to cache the model across requests (a minimal sketch follows this list; the real implementation is the EmbeddingCache class and get_cache() in agent.py).
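A minimal sketch of that caching pattern, assuming only that the worker process stays alive between requests:

_model = None  # module-level global: survives across requests handled by the same worker

def get_model():
    """Load the embedding model on first use, then reuse it for every request."""
    global _model
    if _model is None:
        from langchain_huggingface import HuggingFaceEmbeddings
        _model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return _model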
Usage
orpheus run embedding-cache '{"action": "load"}'
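Search and status requests follow the same pattern (the collection name is illustrative and must match a directory under COLLECTIONS_DIR):
orpheus run embedding-cache '{"action": "search", "collection": "collection_1", "query": "What is vector search?", "top_k": 3}'
orpheus run embedding-cache '{"action": "status"}'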
Source Code
- agent.yaml
- agent.py
name: embedding-cache
runtime: python3
module: agent
entrypoint: handler
memory: 512 # CONFIGURABLE for different OOM tests
timeout: 300
# ServiceManager integration
model: tinyllama
engine: ollama # Using Ollama for CPU testing (no rate limits)
env:
- COLLECTIONS_DIR=/agent/data/collections
- MAX_COLLECTIONS=2
# Telemetry configuration - custom labels for Prometheus filtering
telemetry:
enabled: true
labels:
team: ml_platform
tier: memory_intensive
use_case: embedding
scaling:
min_workers: 1
max_workers: 5
target_utilization: 1.5
scale_up_threshold: 2.0
scale_down_threshold: 0.3
#!/usr/bin/env python3
"""
Embedding Cache Agent for Orpheus: Memory-Intensive ML Workload (OOM Testing).
This agent demonstrates Orpheus's OOM detection and recovery:
1. Real ML Memory Patterns:
- Loads sentence-transformers embedding model (~400MB)
- Loads multiple FAISS indices (~150MB per 10K docs)
- Tests actual memory pressure from ML libraries
2. OOM Detection (exit 137):
- ServiceManager detects OOM kills
- Applies exponential backoff (60s+)
- Auto-replaces worker after 3 failures
3. Memory Tracking:
- Uses psutil to track real RSS memory
- Logs memory before/after each load
- Returns actual memory stats
The agent supports actions:
- load: Load collections (triggers OOM if over limit)
- search: Search loaded collections
- status: Return memory usage and loaded collections
"""
import os
import sys
import time
import psutil
from pathlib import Path
from typing import Dict, List
# RAG dependencies (same as rag-agent)
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaLLM
from langchain_core.prompts import PromptTemplate
# Configuration
COLLECTIONS_DIR = os.getenv("COLLECTIONS_DIR", "./data/collections")
MAX_COLLECTIONS = int(os.getenv("MAX_COLLECTIONS", "2"))
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Model endpoint - ServiceManager injects MODEL_URL automatically
MODEL_URL = os.getenv("MODEL_URL") or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
MODEL_NAME = os.getenv("OLLAMA_MODEL", "mistral")
class EmbeddingCache:
"""Manages multiple FAISS indices with memory tracking."""
def __init__(self):
self.embedding_model = None # HuggingFaceEmbeddings
self.collections = {} # Dict[str, FAISS]
self.max_collections = MAX_COLLECTIONS
self.llm = None
self._load_times = {}
def get_memory_usage_mb(self) -> float:
"""Get current process memory usage in MB."""
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def load_embedding_model(self):
"""Load sentence-transformers embedding model (~400MB)."""
if self.embedding_model is not None:
return # Already loaded
print(f"[cache] Loading embedding model: {EMBEDDING_MODEL}", file=sys.stderr, flush=True)
print(f"[cache] This will use approximately 400MB of memory...", file=sys.stderr, flush=True)
mem_before = self.get_memory_usage_mb()
start_time = time.time()
self.embedding_model = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL,
model_kwargs={"device": "cpu"},
encode_kwargs={"normalize_embeddings": True},
)
elapsed = time.time() - start_time
mem_after = self.get_memory_usage_mb()
mem_delta = mem_after - mem_before
print(
f"[cache] Embedding model loaded: {mem_delta:.1f}MB in {elapsed:.2f}s (total: {mem_after:.1f}MB)",
file=sys.stderr,
flush=True,
)
self._load_times["embedding_model"] = elapsed
def load_collection(self, name: str) -> bool:
"""Load a FAISS index (~150MB per 10K docs)."""
if name in self.collections:
print(f"[cache] Collection '{name}' already loaded", file=sys.stderr, flush=True)
return True
# Ensure embedding model is loaded first
if self.embedding_model is None:
self.load_embedding_model()
# Check if we're at max collections
if len(self.collections) >= self.max_collections:
print(
f"[cache] Max collections reached ({self.max_collections}), cannot load '{name}'",
file=sys.stderr,
flush=True,
)
return False
# Build path to FAISS index
collection_path = Path(COLLECTIONS_DIR) / name / "faiss_index"
if not collection_path.exists():
print(f"[cache] Collection '{name}' not found at {collection_path}", file=sys.stderr, flush=True)
return False
print(f"[cache] Loading collection '{name}' from {collection_path}", file=sys.stderr, flush=True)
print(f"[cache] This will use approximately 150MB of memory...", file=sys.stderr, flush=True)
mem_before = self.get_memory_usage_mb()
start_time = time.time()
vectorstore = FAISS.load_local(
str(collection_path),
self.embedding_model,
allow_dangerous_deserialization=True,
)
elapsed = time.time() - start_time
mem_after = self.get_memory_usage_mb()
mem_delta = mem_after - mem_before
self.collections[name] = vectorstore
self._load_times[f"collection_{name}"] = elapsed
print(
f"[cache] Collection '{name}' loaded: {mem_delta:.1f}MB in {elapsed:.2f}s (total: {mem_after:.1f}MB)",
file=sys.stderr,
flush=True,
)
return True
def load_all_available(self) -> List[str]:
"""Load all available collections up to MAX_COLLECTIONS limit."""
collections_dir = Path(COLLECTIONS_DIR)
if not collections_dir.exists():
print(f"[cache] Collections directory not found: {collections_dir}", file=sys.stderr, flush=True)
return []
# Find all collection directories
available = sorted([d.name for d in collections_dir.iterdir() if d.is_dir()])
print(f"[cache] Found {len(available)} collections: {available}", file=sys.stderr, flush=True)
loaded = []
for name in available[: self.max_collections]:
if self.load_collection(name):
loaded.append(name)
return loaded
def search_collection(self, collection_name: str, query: str, top_k: int = 3) -> dict:
"""Search a loaded collection."""
if collection_name not in self.collections:
return {
"status": "error",
"error": f"Collection '{collection_name}' not loaded. Available: {list(self.collections.keys())}",
}
vectorstore = self.collections[collection_name]
start_time = time.time()
docs = vectorstore.similarity_search(query, k=top_k)
elapsed = time.time() - start_time
return {
"status": "success",
"collection": collection_name,
"query": query,
"results": [
{"content": doc.page_content[:200] + "...", "metadata": doc.metadata} for doc in docs
],
"retrieval_ms": int(elapsed * 1000),
}
def get_status(self) -> dict:
"""Get current cache status and memory usage."""
mem_current = self.get_memory_usage_mb()
return {
"status": "success",
"memory_mb": round(mem_current, 1),
"max_collections": self.max_collections,
"collections_loaded": list(self.collections.keys()),
"embedding_model_loaded": self.embedding_model is not None,
"load_times": {k: round(v, 2) for k, v in self._load_times.items()},
}
# Global cache instance (persists across invocations)
_cache = None
def get_cache() -> EmbeddingCache:
"""Get or create the global cache instance."""
global _cache
if _cache is None:
_cache = EmbeddingCache()
return _cache
def handler(input_data: dict) -> dict:
"""
Orpheus agent handler for embedding cache operations.
Input:
{
"action": "load" | "search" | "status",
"collection": "collection_1", # for search
"query": "search query", # for search
"top_k": 3 # for search (optional)
}
Output:
{
"status": "success" | "error",
"action": "load" | "search" | "status",
"collections_loaded": [...],
"memory_mb": 687.3,
...
}
"""
try:
cache = get_cache()
action = input_data.get("action", "load")
print(f"[cache] Action: {action}", file=sys.stderr, flush=True)
print(f"[cache] Current memory: {cache.get_memory_usage_mb():.1f}MB", file=sys.stderr, flush=True)
if action == "load":
# Load collections - this is where OOM can occur
start_time = time.time()
loaded = cache.load_all_available()
elapsed = time.time() - start_time
status = cache.get_status()
status.update(
{
"action": "load",
"load_time_sec": round(elapsed, 2),
"collections_loaded": loaded,
}
)
print(f"[cache] Load complete: {len(loaded)} collections, {status['memory_mb']}MB", file=sys.stderr, flush=True)
return status
elif action == "search":
collection = input_data.get("collection")
query = input_data.get("query")
top_k = input_data.get("top_k", 3)
if not collection or not query:
return {
"status": "error",
"action": "search",
"error": "Missing 'collection' or 'query' field",
}
result = cache.search_collection(collection, query, top_k)
result["action"] = "search"
result["memory_mb"] = round(cache.get_memory_usage_mb(), 1)
return result
elif action == "status":
status = cache.get_status()
status["action"] = "status"
return status
else:
return {
"status": "error",
"error": f"Unknown action: {action}. Use 'load', 'search', or 'status'",
}
except Exception as e:
print(f"[cache] ERROR: {e}", file=sys.stderr, flush=True)
return {
"status": "error",
"action": input_data.get("action", "unknown"),
"error": str(e),
}
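For a quick check outside the Orpheus runtime, the handler can also be invoked directly. This is a minimal sketch: it assumes the agent's Python dependencies are installed, the script runs from the agent's directory, and COLLECTIONS_DIR points at prebuilt FAISS indices.

from agent import handler

# Load whatever collections exist under COLLECTIONS_DIR (this is the step that can OOM)
print(handler({"action": "load"}))

# Report cached collections and current RSS memory
print(handler({"action": "status"}))

# Search one of the loaded collections (collection name is illustrative)
print(handler({"action": "search", "collection": "collection_1", "query": "What is vector search?"}))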

